{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3fef5a94-b06f-48ae-90e5-2f919d3352bd",
   "metadata": {},
   "source": [
    "This notebook shows how to ingest and search images using ColPali with Elasticsearch. Read our accompanying blog post on [ColPali in Elasticsearch](elastiacsearch-colpali-visual-document-search) for more context on this notebook. \n",
    "\n",
    "We will be using images from the [ViDoRe benchmark](https://huggingface.co/collections/vidore/vidore-benchmark-667173f98e70a1c0fa4db00d) as example data. \n",
    "\n",
    "The URL and API key for your Elasticsearch cluster are expected in a file `elastic.env` in this format: \n",
    "```\n",
    "ELASTIC_HOST=<cluster-url>\n",
    "ELASTIC_API_KEY=<api-key>\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a1610e61-fbfe-4d7f-9109-601a0ccd0129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) always installs into the environment of the running kernel.\n",
    "%pip install -q -r requirements.txt\n",
    "from IPython.display import clear_output\n",
    "\n",
    "clear_output()  # for less space usage."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aec6865f-dc2d-4242-a568-2fbf94cf2201",
   "metadata": {},
   "source": [
    "First we load the sample data from Hugging Face and save it to disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "baf63024-c058-4e1c-a170-6730bf2f2704",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-02T09:16:41.680203Z",
     "start_time": "2025-03-02T09:14:00.648234Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c0aa31eaa8546478c3e48fcc206dbd3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving images to disk:   0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from tqdm.notebook import tqdm\n",
    "import os\n",
    "\n",
    "DATASET_NAME = \"vidore/infovqa_test_subsampled\"\n",
    "DOCUMENT_DIR = \"searchlabs-colpali\"\n",
    "\n",
    "# Make sure the target folder exists before pulling the test split.\n",
    "os.makedirs(DOCUMENT_DIR, exist_ok=True)\n",
    "dataset = load_dataset(DATASET_NAME, split=\"test\")\n",
    "\n",
    "# Write every row's image to DOCUMENT_DIR as image_<row index>.jpg.\n",
    "progress = tqdm(dataset, desc=\"Saving images to disk\")\n",
    "for index, record in enumerate(progress):\n",
    "    target_path = os.path.join(DOCUMENT_DIR, f\"image_{index}.jpg\")\n",
    "    record.get(\"image\").save(target_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da958778-42b3-438a-992c-172097d8d464",
   "metadata": {},
   "source": [
    "Here we load the ColPali model and define functions to generate vectors from images and text. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "32aad27e-afe7-4bbf-ad13-1be82c917a70",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-02T09:16:41.681836Z",
     "start_time": "2025-03-02T09:14:10.977348Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0e608222306e4cd1a1b933ed248e74f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from colpali_engine.models import ColPali, ColPaliProcessor\n",
    "\n",
    "model_name = \"vidore/colpali-v1.3\"\n",
    "\n",
    "# Pick the best available backend instead of assuming Apple Silicon:\n",
    "# CUDA GPU first, then Apple's MPS, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Reuse model_name so the model and the processor always come from the same checkpoint.\n",
    "model = ColPali.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float32,\n",
    "    device_map=device,\n",
    ").eval()\n",
    "\n",
    "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
    "\n",
    "\n",
    "def create_col_pali_image_vectors(image_path: str) -> list:\n",
    "    \"\"\"Run the ColPali model on the image stored at ``image_path``.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image file to embed.\n",
    "\n",
    "    Returns:\n",
    "        The first (and only) entry of the model output for a one-image batch,\n",
    "        converted to nested Python lists. Presumably one embedding vector per\n",
    "        image token -- confirm against the ColPali documentation.\n",
    "    \"\"\"\n",
    "    # The context manager closes the file handle as soon as the image has been\n",
    "    # preprocessed, instead of leaving it open until garbage collection.\n",
    "    with Image.open(image_path) as image:\n",
    "        batch_images = col_pali_processor.process_images([image]).to(model.device)\n",
    "\n",
    "    # Inference only: no gradients are needed.\n",
    "    with torch.no_grad():\n",
    "        return model(**batch_images).tolist()[0]\n",
    "\n",
    "\n",
    "def create_col_pali_query_vectors(query: str) -> list:\n",
    "    \"\"\"Run the ColPali model on a text query.\n",
    "\n",
    "    Args:\n",
    "        query: The search text to embed.\n",
    "\n",
    "    Returns:\n",
    "        The first (and only) entry of the model output for a one-query batch,\n",
    "        converted to nested Python lists.\n",
    "    \"\"\"\n",
    "    batch_queries = col_pali_processor.process_queries([query])\n",
    "    batch_queries = batch_queries.to(model.device)\n",
    "\n",
    "    # Inference only: no gradients are needed.\n",
    "    with torch.no_grad():\n",
    "        embeddings = model(**batch_queries)\n",
    "    return embeddings.tolist()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d12ea156-4e2b-4b84-983e-d0e63c9a6178",
   "metadata": {},
   "source": [
    "Next, we iterate over all saved images, create their multi-vectors with the ColPali model, and store the result in a pickle file. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fcf55e15-6c4a-4003-b929-aab2931c2389",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-02T09:16:41.682259Z",
     "start_time": "2025-03-02T09:14:22.244797Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76b9d003d76d49d1b82b25d124deddeb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Create ColPali Vectors:   0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved 500 vector entries to disk\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import pickle\n",
    "\n",
    "# Only pick up JPEG files: os.listdir also returns stray entries (for example\n",
    "# .DS_Store on macOS) that PIL cannot open. Sorting makes the order deterministic.\n",
    "images = sorted(\n",
    "    os.path.join(DOCUMENT_DIR, name)\n",
    "    for name in os.listdir(DOCUMENT_DIR)\n",
    "    if name.lower().endswith(\".jpg\")\n",
    ")\n",
    "file_to_multi_vectors = {}\n",
    "\n",
    "# Map each file name to the multi-vector the model produces for that image.\n",
    "for image_path in tqdm(images, desc=\"Create ColPali Vectors\"):\n",
    "    file_name = os.path.basename(image_path)\n",
    "    vectors_f32 = create_col_pali_image_vectors(image_path)\n",
    "    file_to_multi_vectors[file_name] = vectors_f32\n",
    "\n",
    "# Persist the vectors so that indexing can be repeated without re-running the model.\n",
    "with open(\"col_pali_vectors.pkl\", \"wb\") as f:\n",
    "    pickle.dump(file_to_multi_vectors, f)\n",
    "\n",
    "print(f\"Saved {len(file_to_multi_vectors)} vector entries to disk\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39512c53-2679-4ae0-a92a-31cb6374b60b",
   "metadata": {},
   "source": [
    "Here we create an index with the new `rank_vectors` field type, which will store our ColPali vectors, and define a helper that indexes a document with retries. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-02T09:16:41.682532Z",
     "start_time": "2025-03-02T09:14:22.828689Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Index 'searchlabs-colpali' already exists.\n"
     ]
    }
   ],
   "source": [
    "from dotenv import load_dotenv\n",
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "# Connection settings are read from elastic.env.\n",
    "load_dotenv(\"elastic.env\")\n",
    "\n",
    "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
    "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
    "INDEX_NAME = \"searchlabs-colpali\"\n",
    "\n",
    "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
    "\n",
    "# A single rank_vectors field holds the multi-vectors of one image.\n",
    "mappings = {\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"col_pali_vectors\": {\"type\": \"rank_vectors\"},\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "# Create the index only when it is missing.\n",
    "index_exists = es.indices.exists(index=INDEX_NAME)\n",
    "if index_exists:\n",
    "    print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
    "else:\n",
    "    print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
    "    es.indices.create(index=INDEX_NAME, body=mappings)\n",
    "\n",
    "\n",
    "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
    "    \"\"\"Index a single document, retrying with exponential backoff on failure.\n",
    "\n",
    "    Args:\n",
    "        es_client: Client object exposing ``index(index=..., id=..., document=...)``.\n",
    "        index: Name of the target index.\n",
    "        doc_id: Id to store the document under.\n",
    "        document: Document body to index.\n",
    "        retries: Maximum number of attempts; must be at least 1.\n",
    "        initial_backoff: Seconds to wait after the first failed attempt; the wait\n",
    "            doubles after every further failure.\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``es_client.index`` returns for the first successful attempt.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``retries`` is smaller than 1.\n",
    "        Exception: The error of the last attempt, re-raised once all attempts failed.\n",
    "    \"\"\"\n",
    "    import time  # local import: keeps the helper independent of imports in other cells\n",
    "\n",
    "    # With retries < 1 the loop below would never run and the function would\n",
    "    # silently return None without indexing anything.\n",
    "    if retries < 1:\n",
    "        raise ValueError(f\"retries must be at least 1, got {retries}\")\n",
    "\n",
    "    for attempt in range(1, retries + 1):\n",
    "        try:\n",
    "            return es_client.index(index=index, id=doc_id, document=document)\n",
    "        except Exception as e:\n",
    "            if attempt == retries:\n",
    "                # Out of attempts: report and let the caller see the original error.\n",
    "                print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
    "                raise\n",
    "            # Exponential backoff: initial_backoff, 2x, 4x, ...\n",
    "            wait_time = initial_backoff * (2 ** (attempt - 1))\n",
    "            print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
    "            time.sleep(wait_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "697f7b77-eb8f-430b-af3d-4a618c5cf086",
   "metadata": {},
   "source": [
    "Load the saved vectors back from disk and index them into Elasticsearch, skipping documents that are already in the index. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "adb4b59c-b36f-44d9-bee9-d20457630330",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-02T09:16:41.682771Z",
     "start_time": "2025-03-02T09:14:24.511339Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ee853046b1b4929b0784c41221ba03f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Index documents:   0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Read back the vectors that were pickled after the embedding step.\n",
    "with open(\"col_pali_vectors.pkl\", \"rb\") as vector_file:\n",
    "    file_to_multi_vectors = pickle.load(vector_file)\n",
    "\n",
    "progress = tqdm(file_to_multi_vectors.items(), desc=\"Index documents\")\n",
    "for file_name, vectors in progress:\n",
    "    # The file name doubles as document id, so documents already present are skipped.\n",
    "    already_indexed = es.exists(index=INDEX_NAME, id=file_name)\n",
    "    if not already_indexed:\n",
    "        index_document(\n",
    "            es_client=es,\n",
    "            index=INDEX_NAME,\n",
    "            doc_id=file_name,\n",
    "            document={\"col_pali_vectors\": vectors},\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b556999-06b7-4856-a6bb-d72ad4929f62",
   "metadata": {},
   "source": [
    "Use the new `maxSimDotProduct` function to calculate the similarity between our query and the image vectors in Elasticsearch. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-02T09:16:41.683169Z",
     "start_time": "2025-03-02T09:14:24.521130Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display, HTML\n",
    "import os\n",
    "\n",
    "query = \"What do companies use for recruiting?\"\n",
    "\n",
    "# Embed the question once, then hand the multi-vector to the scoring script.\n",
    "query_vectors = create_col_pali_query_vectors(query)\n",
    "scoring_script = {\n",
    "    \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
    "    \"params\": {\"query_vector\": query_vectors},\n",
    "}\n",
    "es_query = {\n",
    "    \"_source\": False,\n",
    "    \"query\": {\n",
    "        \"script_score\": {\"query\": {\"match_all\": {}}, \"script\": scoring_script}\n",
    "    },\n",
    "    \"size\": 5,\n",
    "}\n",
    "\n",
    "results = es.search(index=INDEX_NAME, body=es_query)\n",
    "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
    "\n",
    "# Render the hits side by side; each document id is joined onto DOCUMENT_DIR to get the image path.\n",
    "image_tags = [\n",
    "    f'<img src=\"{os.path.join(DOCUMENT_DIR, image_id)}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
    "    for image_id in image_ids\n",
    "]\n",
    "html = (\n",
    "    \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
    "    + \"\".join(image_tags)\n",
    "    + \"</div>\"\n",
    ")\n",
    "\n",
    "display(HTML(html))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16997bc1-ea8d-413b-a312-00f08fca1d0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
    "# os._exit() skips the normal interpreter cleanup and does not flush stdio buffers,\n",
    "# so the message is flushed explicitly before exiting.\n",
    "print(\"Shutting down the kernel to free memory...\", flush=True)\n",
    "os._exit(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dependecy-test-colpali-blog",
   "language": "python",
   "name": "dependecy-test-colpali-blog"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
